from ._forward import ForwardPasser
from ._pruning import PruningPasser, FEAT_IMP_CRITERIA
from ._util import ascii_table, apply_weights_2d, gcv
from ._types import BOOL
from sklearn.base import RegressorMixin, BaseEstimator, TransformerMixin
from sklearn.utils.validation import (assert_all_finite, check_is_fitted,
                                      check_X_y)
import numpy as np
from scipy import sparse
from ._version import get_versions
__version__ = get_versions()['version']

class Earth(BaseEstimator, RegressorMixin, TransformerMixin):

    """
    Multivariate Adaptive Regression Splines

    A flexible regression method that automatically searches for interactions
    and non-linear relationships.  Earth models can be thought of as
    linear models in a higher dimensional basis space
    (specifically, a multivariate truncated power spline basis).
    Each term in an Earth model is a product of so called "hinge functions".
    A hinge function is a function that's equal to its argument where that
    argument is greater than zero and is zero everywhere else.

    The multivariate adaptive regression splines algorithm has two stages.
    First, the forward pass searches for terms in the truncated power spline
    basis that locally minimize the squared error loss of the training set.
    Next, a pruning pass selects a subset of those terms that produces
    a locally minimal generalized cross-validation (GCV) score.  The GCV score
    is not actually based on cross-validation, but rather is meant to
    approximate a true cross-validation score by penalizing model complexity.
    The final result is a set of terms that is nonlinear in the original
    feature space, may include interactions, and is likely to generalize well.

    The Earth class supports dense input only.  Data structures from the
    pandas and patsy modules are supported, but are copied into numpy arrays
    for computation.  No copy is made if the inputs are numpy float64 arrays.
    Earth objects can be serialized using the pickle module and copied
    using the copy module.


    Parameters
    ----------
    max_terms : int, optional (default=min(2 * n + m // 10, 400),
                               where n is the number of features and m is the
                               number of rows)
        The maximum number of terms generated by the forward pass.  All memory is
        allocated at the beginning of the forward pass, so setting max_terms to 
        a very high number on a system with insufficient memory may cause a 
        MemoryError at the start of the forward pass.


    max_degree : int, optional (default=1)
        The maximum degree of terms generated by the forward pass.


    allow_missing : boolean, optional (default=False)
        If True, use missing data method described in [3].
        Use the missing argument to determine missingness or, if X is a
        pandas DataFrame, infer missingness from X.


    penalty : float, optional (default=3.0)
        A smoothing parameter used to calculate GCV and GRSQ.
        Used during the pruning pass and to determine whether to add a hinge
        or linear basis function during the forward pass.
        See the d parameter in equation 32, Friedman, 1991.


    endspan_alpha : float, optional, probability between 0 and 1 (default=0.05)
        A parameter controlling the calculation of the endspan
        parameter (below).  The endspan parameter is calculated as
        round(3 - log2(endspan_alpha/n)), where n is the number of features.
        The endspan_alpha parameter represents the probability of a run of
        positive or negative error values on either end of the data vector
        of any feature in the data set.  See equation 45, Friedman, 1991.


    endspan : int, optional (default=-1)
        The number of extreme data values of each feature not eligible
        as knot locations. If endspan is set to -1 (default) then the
        endspan parameter is calculated based on endspan_alpha (above).
        If endspan is set to a positive integer then endspan_alpha is ignored.


    minspan_alpha : float, optional, probability between 0 and 1 (default=0.05)
        A parameter controlling the calculation of the minspan
        parameter (below).  The minspan parameter is calculated as

            (int) -log2(-(1.0/(n*count))*log(1.0-minspan_alpha)) / 2.5

        where n is the number of features and count is the number of points at
        which the parent term is non-zero.  The minspan_alpha parameter
        represents the probability of a run of positive or negative error
        values between adjacent knots separated by minspan intervening
        data points. See equation 43, Friedman, 1991.


    minspan : int, optional (default=-1)
        The minimal number of data points between knots.  If minspan is set
        to -1 (default) then the minspan parameter is calculated based on
        minspan_alpha (above).  If minspan is set to a positive integer then
        minspan_alpha is ignored.


    thresh : float, optional (default=0.001)
        Parameter used when evaluating stopping conditions for the forward
        pass. If either RSQ > 1 - thresh or if RSQ increases by less than
        thresh for a forward pass iteration then the forward pass is
        terminated.


    zero_tol : float, optional (default=1e-12)
        Used when determining whether a floating point number is zero during
        the  forward pass.  This is important in determining linear dependence
        and in the fast update procedure.  There should normally be no reason
        to change  zero_tol from its default. However, if nans are showing up
        during the forward pass or the forward pass seems to be terminating
        unexpectedly, consider adjusting zero_tol.

    min_search_points : int, optional (default=100)
        Used to calculate check_every (below).  The minimum samples necessary
        for check_every to be greater than 1.  The check_every parameter
        is calculated as

             (int) m / min_search_points

        if m > min_search_points, where m is the number of samples in the
        training set.  If m <= min_search_points then check_every is set to 1.


    check_every : int, optional (default=-1)
        If check_every > 0, only one of every check_every sorted data points
        is considered as a candidate knot.  If check_every is set to -1 then
        the check_every parameter is calculated based on
        min_search_points (above).


    allow_linear : bool, optional (default=True)
        If True, the forward pass will check the GCV of each new pair of terms
        and, if it's not an improvement on a single term with no knot (called a
        linear term, although it may actually be a product of a linear term
        with some other parent term), then only that single, knotless term will
        be used. If False, that behavior is disabled and all terms will have
        knots except those with variables specified by the linvars argument
        (see the fit method).

    use_fast : bool, optional (default=False)
        if True, use the approximation procedure defined in [2] to speed up the
        forward pass. The procedure uses two hyper-parameters : fast_K
        and fast_h. Check below for more details.

    fast_K : int, optional (default=5)
        Only used if use_fast is True. As defined in [2], section 3.0, it
        defines the maximum number of basis functions to look at when
        we search for a parent, that is we look at only the fast_K top
        terms ranked by the mean squared error of the model the last time
        the term was chosen as a parent. The smaller fast_K is, the more
        gains in speed we get but the more approximate is the result.
        If fast_K is the maximum number of terms and fast_h is 1,
        the behavior is the same as in the normal case
        (when use_fast is False).

    fast_h : int, optional (default=1)
        Only used if use_fast is True. As defined in [2], section 4.0, it
        determines the number of iterations before repassing through all
        the variables when searching for the variable to use for a
        given parent term. Before reaching fast_h number of iterations
        only the last chosen variable for the parent term is used. The
        bigger fast_h is, the more speed gains we get, but the result
        is more approximate.

    smooth : bool, optional (default=False)
        If True, the model will be smoothed such that it has continuous first
        derivatives.
        For details, see section 3.7, Friedman, 1991.

    enable_pruning : bool, optional(default=True)
        If False, the pruning pass will be skipped.

    feature_importance_type: string or list of strings, optional (default=None)
        Specify which kind of feature importance criteria to compute.
        Currently three criteria are supported : 'gcv', 'rss' and 'nb_subsets'.
        By default (when it is None), no feature importance is computed.
        Feature importance is a measure of the effect of the features
        on the outputs. For each feature, the values go from
        0 to 1 and sum up to 1. A high value means the feature has, on
        average (over the population), a large effect on the outputs.
        See [4], section 12.3 for more information about the criteria.

    verbose : int, optional(default=0)
        If verbose >= 1, print out progress information during fitting.  If
        verbose >= 2, also print out information on numerical difficulties
        if encountered during fitting. If verbose >= 3, print even more
        information that is probably only useful to the developers of py-earth.

    Attributes
    ----------
    `coef_` : array, shape = [pruned basis length, number of outputs]
        The weights of the model terms that have not been pruned.


    `basis_` : _basis.Basis
        An object representing model terms.  Each term is a product of
        constant, linear, and hinge functions of the input features.


    `mse_` : float
        The mean squared error of the model after the final linear fit.
        If sample_weight and/or output_weight are given, this score is
        weighted appropriately.


    `rsq_` : float
        The generalized r^2 of the model after the final linear fit.
        If sample_weight and/or output_weight are given, this score is
        weighted appropriately.

    `gcv_` : float
        The generalized cross validation (GCV) score of the model after the
        final linear fit. If sample_weight and/or output_weight are
        given, this score is weighted appropriately.

    `grsq_` : float
        An r^2 like score based on the GCV. If sample_weight and/or
        output_weight are given, this score is
        weighted appropriately.

    `forward_pass_record_` : _record.ForwardPassRecord
        An object containing information about the forward pass, such as
        training loss function values after each iteration and the final
        stopping condition.


    `pruning_pass_record_` : _record.PruningPassRecord
        An object containing information about the pruning pass, such as
        training loss function values after each iteration and the
        selected optimal iteration.


    `xlabels_` : list
        List of column names for training predictors.
        Defaults to ['x0','x1',....] if column names are not provided.


    `allow_missing_` : list
        List of booleans indicating whether each variable is allowed to
        be missing.  Set during training.  A variable may be missing
        only if fitting included missing data for that variable.

    `feature_importances_`: array of shape [m] or dict
        m is the number of features.
        if one feature importance type is specified, it is an
        array of shape m. If several feature importance types are
        specified, then it is dict where each key is a feature importance type
        name and its corresponding value is an array of shape m.
    
    `_version`: string
        The version of py-earth in which the Earth object was originally 
        created.  This information may be useful when dealing with 
        serialized Earth objects.


    References
    ----------

    .. [1] Friedman, Jerome. Multivariate Adaptive Regression Splines.
           Annals of Statistics. Volume 19, Number 1 (1991), 1-67.

    .. [2] Fast MARS, Jerome H.Friedman, Technical Report No.110, May 1993.

    .. [3] Estimating Functions of Mixed Ordinal and Categorical Variables
           Using Adaptive Splines, Jerome H.Friedman, Technical Report
           No.108, June 1991.

    .. [4] http://www.milbo.org/doc/earth-notes.pdf

    """

    # Names of the instance attributes that _pull_forward_args selects and
    # that forward_pass hands to the ForwardPasser as keyword arguments.
    forward_pass_arg_names = {
        'max_terms',
        'max_degree',
        'allow_missing',
        'penalty',
        'endspan_alpha',
        'endspan',
        'minspan_alpha',
        'minspan',
        'thresh',
        'zero_tol',
        'min_search_points',
        'check_every',
        'allow_linear',
        'use_fast',
        'fast_K',
        'fast_h',
        'feature_importance_type',
        'verbose',
    }
    # Names of the instance attributes that _pull_pruning_args selects.
    pruning_pass_arg_names = {
        'penalty',
        'feature_importance_type',
        'verbose',
    }

    def __init__(self, max_terms=None, max_degree=None, allow_missing=False,
                 penalty=None, endspan_alpha=None, endspan=None,
                 minspan_alpha=None, minspan=None,
                 thresh=None, zero_tol=None, min_search_points=None,
                 check_every=None,
                 allow_linear=None, use_fast=None, fast_K=None,
                 fast_h=None, smooth=None, enable_pruning=True,
                 feature_importance_type=None, verbose=0):

        # Every constructor argument is stored unchanged on the instance; no
        # validation or conversion happens here.
        # A value of None is skipped by _pull_forward_args and
        # _pull_pruning_args, so it is never handed to the forward or
        # pruning pass -- presumably the pass then falls back to its own
        # default (TODO confirm against ForwardPasser / PruningPasser).
        self.max_terms = max_terms
        self.max_degree = max_degree
        self.allow_missing = allow_missing
        self.penalty = penalty
        self.endspan_alpha = endspan_alpha
        self.endspan = endspan
        self.minspan_alpha = minspan_alpha
        self.minspan = minspan
        self.thresh = thresh
        self.zero_tol = zero_tol
        self.min_search_points = min_search_points
        self.check_every = check_every
        self.allow_linear = allow_linear
        self.use_fast = use_fast
        self.fast_K = fast_K
        self.fast_h = fast_h
        self.smooth = smooth
        self.enable_pruning = enable_pruning
        self.feature_importance_type = feature_importance_type
        self.verbose = verbose
        # Remember the module-level package version this object was
        # created with.
        self._version = __version__

    def __eq__(self, other):
        if self.__class__ is not other.__class__:
            return False
        keys = set(self.__dict__.keys()) | set(other.__dict__.keys())
        for k in keys:
            try:
                v_self = self.__dict__[k]
                v_other = other.__dict__[k]
            except KeyError:
                return False
            try:
                if v_self != v_other:
                    return False
            except ValueError:  # Case of numpy arrays
                if np.any(v_self != v_other):
                    return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    def _pull_forward_args(self, **kwargs):
        '''
        Pull named arguments relevant to the forward pass.
        '''
        result = {}
        for name in self.forward_pass_arg_names:
            if name in kwargs and kwargs[name] is not None:
                result[name] = kwargs[name]
        return result

    def _pull_pruning_args(self, **kwargs):
        '''
        Pull named arguments relevant to the pruning pass.
        '''
        result = {}
        for name in self.pruning_pass_arg_names:
            if name in kwargs and kwargs[name] is not None:
                result[name] = kwargs[name]
        return result

    def _scrape_labels(self, X):
        '''
        Try to get labels from input data (for example, if X is a
        pandas DataFrame).  Return None if no labels can be extracted.
        '''
        try:
            labels = list(X.columns)
        except AttributeError:
            try:
                labels = list(X.design_info.column_names)
            except AttributeError:
                try:
                    labels = list(X.dtype.names)
                except TypeError:
                    try:
                        labels = ['x%d' % i for i in range(X.shape[1])]
                    except IndexError:
                        labels = ['x%d' % i for i in range(1)]
                # handle case where X is not np.array (e.g list)
                except AttributeError:
                    X = np.array(X)
                    labels = ['x%d' % i for i in range(X.shape[1])]
        return labels

    def _scrub_x(self, X, missing, **kwargs):
        '''
        Sanitize input predictors and build the missingness indicator.

        Parameters
        ----------
        X : array-like
            The predictors.  Converted to a float64 array; one-dimensional
            input is turned into a single column.

        missing : array-like or None
            Entries that evaluate to True mark the corresponding entry of X
            as missing.  If None, missingness is inferred from the NaN
            entries of X.

        Returns
        -------
        X : numpy array of float64 with two dimensions
            If missingness was inferred, or X contains NaN, the returned
            array is a copy in which the missing entries are set to zero.

        missing : numpy array of dtype BOOL with two dimensions

        Raises
        ------
        TypeError
            If X is a scipy sparse matrix.
        ValueError
            If allow_missing is False and X is not finite or missing marks
            any entry, or if the model already has a basis_ and the number
            of columns of X differs from basis_.num_variables.
        '''
        # Check for sparseness
        if sparse.issparse(X):
            raise TypeError('A sparse matrix was passed, but dense data '
                            'is required. Use X.toarray() to convert to '
                            'dense.')
        X = np.asarray(X, dtype=np.float64, order='F')

        # Figure out missingness
        missing_is_nan = missing is None
        if missing_is_nan:
            # Infer missingness
            missing = np.isnan(X)

        if not self.allow_missing:
            try:
                assert_all_finite(X)
            except ValueError:
                raise ValueError(
                    "Input contains NaN, infinity or a value that's too large."
                    "Did you mean to set allow_missing=True?")
        if X.ndim == 1:
            X = X[:, np.newaxis]

        # Ensure correct number of columns
        if hasattr(self, 'basis_') and self.basis_ is not None:
            if X.shape[1] != self.basis_.num_variables:
                raise ValueError('Wrong number of columns in X. Reshape your data.')

        # Zero-out any missing spots in X
        if np.any(missing):
            if not self.allow_missing:
                raise ValueError('Missing data requires allow_missing=True.')
            if missing_is_nan or np.any(np.isnan(X)):
                # Copy so the caller's data is left alone, keeping the
                # Fortran layout requested above.
                X = X.copy(order='F')
                # Index with a genuine boolean mask.  A 0/1 integer (or
                # otherwise non-boolean) indicator would be interpreted by
                # numpy as integer positions and zero the wrong entries.
                mask = np.asarray(missing, dtype=bool)
                X[mask] = 0.

        # Convert to internally used data type
        missing = np.asarray(missing, dtype=BOOL, order='F')
        assert_all_finite(missing)
        if missing.ndim == 1:
            missing = missing[:, np.newaxis]

        return X, missing

    def _scrub(self, X, y, sample_weight, output_weight, missing, **kwargs):
        '''
        Sanitize input data.

        X and missing are handled by _scrub_x.  y, sample_weight and
        output_weight are converted to float64 arrays, checked for
        finiteness and for compatible dimensions, and output_weight is
        folded into sample_weight by multiplication.  The arrays passed in
        by the caller are never modified.

        Parameters
        ----------
        X : array-like, or the (y, X) tuple returned by patsy.dmatrices
            when y is None.
        y : array-like or None
        sample_weight : array-like or None
            If None, a column of ones is used.
        output_weight : array-like or None
        missing : array-like or None

        Returns
        -------
        (X, y, sample_weight, None, missing) : tuple
            y and sample_weight are two-dimensional float64 arrays.  The
            fourth entry is always None because output_weight has already
            been multiplied into sample_weight.

        Raises
        ------
        TypeError
            If y, sample_weight or output_weight is a scipy sparse matrix.
        ValueError
            If any input is not finite or the dimensions are incompatible.
        '''
        # Check for sparseness
        if sparse.issparse(y):
            raise TypeError('A sparse matrix was passed, but dense data '
                            'is required. Use y.toarray() to convert to '
                            'dense.')
        if sparse.issparse(sample_weight):
            raise TypeError('A sparse matrix was passed, but dense data '
                            'is required. Use sample_weight.toarray()'
                            'to convert to dense.')
        if sparse.issparse(output_weight):
            raise TypeError('A sparse matrix was passed, but dense data '
                            'is required. Use output_weight.toarray()'
                            'to convert to dense.')

        # Check whether X is the output of patsy.dmatrices
        if y is None and isinstance(X, tuple):
            y, X = X

        # Handle X separately
        X, missing = self._scrub_x(X, missing, **kwargs)

        # Convert y to internally used data type
        y = np.asarray(y, dtype=np.float64)
        assert_all_finite(y)

        if len(y.shape) == 1:
            y = y[:, np.newaxis]

        # Deal with sample_weight
        if sample_weight is None:
            sample_weight = np.ones((y.shape[0], 1), dtype=y.dtype)
        else:
            sample_weight = np.asarray(sample_weight, dtype=np.float64)
            assert_all_finite(sample_weight)
            if len(sample_weight.shape) == 1:
                sample_weight = sample_weight[:, np.newaxis]
        # Deal with output_weight
        if output_weight is not None:
            output_weight = np.asarray(output_weight, dtype=np.float64)
            assert_all_finite(output_weight)

        # Make sure dimensions match
        if y.shape[0] != X.shape[0]:
            raise ValueError('X and y do not have compatible dimensions.')
        if y.shape[0] != sample_weight.shape[0]:
            raise ValueError(
                'y and sample_weight do not have compatible dimensions.')
        if output_weight is not None and y.shape[1] != output_weight.shape[0]:
            raise ValueError(
                'y and output_weight do not have compatible dimensions.')
        if y.shape[1] > 1:
            if sample_weight.shape[1] == 1 and output_weight is not None:
                sample_weight = np.repeat(sample_weight, y.shape[1], axis=1)
        if output_weight is not None:
            # Multiply out of place: np.asarray and the newaxis view above
            # may still share memory with the caller's sample_weight, which
            # an in-place *= would silently overwrite.
            sample_weight = sample_weight * output_weight

        # Make sure everything is finite (except X, which is allowed to have
        # missing values)
        assert_all_finite(missing)
        assert_all_finite(y)
        assert_all_finite(sample_weight)
        if output_weight is not None:
            assert_all_finite(output_weight)

        # Make sure everything is consistent
        check_X_y(X, y, accept_sparse=False, multi_output=True,
                  force_all_finite=False)

        return X, y, sample_weight, None, missing

    def fit(self, X, y=None,
            sample_weight=None,
            output_weight=None,
            missing=None,
            xlabels=None, linvars=[]):
        '''
        Fit an Earth model to the input data X and y.

        Runs the forward pass, the pruning pass (unless enable_pruning is
        False), optional smoothing of the basis, and the final linear fit.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features
            The training predictors.
            The X parameter can be a numpy array, a pandas DataFrame, a patsy
            DesignMatrix, or a tuple of patsy DesignMatrix objects as
            output by patsy.dmatrices.


        y : array-like, optional (default=None), shape = [m, p] where m is the
            number of samples and p the number of outputs
            The training response.
            The y parameter can be a numpy array, a pandas DataFrame,
            a Patsy DesignMatrix, or can be left as None (default) if X was
            the output of a call to patsy.dmatrices (in which case, X contains
            the response).


        sample_weight : array-like, optional (default=None), shape = [m]
             where m is the number of samples.
             Sample weights for training.  Weights must be greater than or
             equal to zero. Rows with zero weight do not contribute at all.
             Weights are useful when dealing with heteroscedasticity.
             In such cases, the weight should be proportional to the inverse of
             the (known) variance.

        output_weight : array-like, optional (default=None), shape = [p]
             where p is the number of outputs.
             Output weights for training. Weights must be greater than or equal
             to zero. Outputs with zero weight do not contribute at all.

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            The missing parameter can be a numpy array, a pandas DataFrame, or
            a  patsy DesignMatrix.  All entries will be interpreted as boolean
            values, with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred
            from X if allow_missing is True.

        linvars : iterable of strings or ints, optional (empty by default)
            Used to specify features that may only enter terms as linear basis
            functions (without knots).  Can include both column numbers and
            column names (see xlabels, below).  If left empty, some variables
            may still enter linearly during the forward pass if no knot would
            provide a reduction in GCV compared to the linear function.
            Note that this feature differs from the R package earth.


        xlabels : iterable of strings, optional (default=None)
            The xlabels argument can be used to assign names to data columns.
            This argument is not generally needed, as names can be captured
            automatically from most standard data structures.
            If included, must have length n, where n is the number of features.
            Note that column order is used to compute term values and make
            predictions, not column names.


        Returns
        -------
        self : Earth
            The fitted estimator.


        Raises
        ------
        ValueError
            If xlabels does not have one entry per column of X, if
            feature_importance_type names an unknown criterion, or if
            feature importances are requested while pruning is disabled.

        '''
        # Labels have to be scraped from the raw input, before _scrub
        # converts it into a plain numpy array.
        if xlabels is None:
            label_source = X
            if y is None and isinstance(X, tuple) and len(X) == 2:
                # Output of patsy.dmatrices is (response, predictors); only
                # the predictors carry the column names wanted here.
                label_source = X[1]
            self.xlabels_ = self._scrape_labels(label_source)

        if self.feature_importance_type is not None:
            feature_importance_type = self.feature_importance_type
            try:
                # Python 2: also accept unicode strings.
                is_str = isinstance(feature_importance_type, basestring)
            except NameError:
                is_str = isinstance(feature_importance_type, str)
            if is_str:
                feature_importance_type = [feature_importance_type]
            for k in feature_importance_type:
                if k not in FEAT_IMP_CRITERIA:
                    msg = ("'{}' is not valid value for feature_importance, "
                           "allowed critera are : {}".format(k, FEAT_IMP_CRITERIA))
                    raise ValueError(msg)

            if len(feature_importance_type) > 0 and self.enable_pruning is False:
                raise ValueError("Cannot compute feature importances because pruning is disabled,"
                                 "please re-enable pruning by setting enable_pruning to True in order"
                                 "to enable feature importance estimation")

        self.linvars_ = linvars
        X, y, sample_weight, output_weight, missing = self._scrub(
            X, y, sample_weight, output_weight, missing)

        # Validate user supplied labels against the scrubbed X, which is
        # guaranteed to be a two-dimensional array.  The raw input may be a
        # list, a one-dimensional array or a patsy tuple, none of which has
        # a usable shape[1].
        if xlabels is not None:
            if len(xlabels) != X.shape[1]:
                raise ValueError('The length of xlabels is not the '
                                 'same as the number of columns of X')
            self.xlabels_ = xlabels

        # Do the actual work
        self.forward_pass(X, y,
                          sample_weight, output_weight, missing,
                          self.xlabels_, linvars, skip_scrub=True)
        if self.enable_pruning is True:
            self.pruning_pass(X, y,
                              sample_weight, output_weight, missing,
                              skip_scrub=True)
        # hasattr guards against instances that lack the smooth attribute.
        if hasattr(self, 'smooth') and self.smooth:
            self.basis_ = self.basis_.smooth(X)
        self.linear_fit(X, y, sample_weight, output_weight, missing,
                        skip_scrub=True)
        return self

#     def forward_pass2(self, X, y=None,
#                        sample_weight=None, output_weight=None,
#                        missing=None,
#                        xlabels=None, linvars=[]):
#         # Label and format data
#         if xlabels is None:
#             self.xlabels_ = self._scrape_labels(X)
#         else:
#             self.xlabels_ = xlabels
#         X, y, sample_weight, output_weight, missing = self._scrub(
#             X, y, sample_weight, output_weight, missing)
#
#         # Do the actual work
#         args = self._pull_forward_args(**self.__dict__)
#         forward_passer = ForwardPasser(
#             X, missing, y, sample_weight, output_weight,
#             xlabels=self.xlabels_, linvars=linvars, **args)
# #         forward_passer.run()
#         linvars_ = []
#         self.forward_pass_record_, self.basis_ = forward_pass(X, missing, y,
#             sample_weight, output_weight, xlabels=self.xlabels_, linvars=linvars)
# #         self.basis_ = forward_passer.get_basis()
#
    def forward_pass(self, X, y=None,
                     sample_weight=None, output_weight=None,
                     missing=None,
                     xlabels=None, linvars=[], skip_scrub=False):
        '''
        Perform the forward pass of the multivariate adaptive regression
        splines algorithm.  Users will normally want to call the fit method
        instead, which performs the forward pass, the pruning pass,
        and a linear fit to determine the final model coefficients.

        On completion the attributes xlabels_, forward_pass_record_ and
        basis_ are set on the estimator.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples and n
            is the number of features
            The training predictors.
            The X parameter can be a numpy array, a pandas DataFrame, a patsy
            DesignMatrix, or a tuple of patsy DesignMatrix objects as output
            by patsy.dmatrices.

        y : array-like, optional (default=None), shape = [m, p] where m is the
            number of samples, p the number of outputs.
            The y parameter can be a numpy array, a pandas DataFrame,
            a Patsy DesignMatrix, or can be left as None (default) if X was
            the output of a call to patsy.dmatrices (in which case, X contains
            the response).

        sample_weight : array-like, optional (default=None), shape = [m]
             where m is the number of samples.
             Sample weights for training.  Weights must be greater than or
             equal to zero. Rows with zero weight do not contribute at all.
             Weights are useful when dealing with heteroscedasticity.  In such
             cases, the weight should be proportional to the inverse of the
             (known) variance.

        output_weight : array-like, optional (default=None), shape = [p]
             where p is the number of outputs.
             The total mean squared error (MSE) is a weighted sum of
             mean squared errors (MSE) associated to each output, where
             the weights are given by output_weight.
             Output weights must be greater than or equal
             to zero. Outputs with zero weight do not contribute at all
             to the total mean squared error (MSE).
             The output weights are folded into sample_weight when the
             inputs are scrubbed; with skip_scrub=True this argument is
             not used.

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            The missing parameter can be a numpy array, a pandas DataFrame, or
            a  patsy DesignMatrix.  All entries will be interpreted as boolean
            values, with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred
            from X if allow_missing is True.

        linvars : iterable of strings or ints, optional (empty by default)
            Used to specify features that may only enter terms as linear basis
            functions (without knots).  Can include both column numbers and
            column names (see xlabels, below).


        xlabels : iterable of strings, optional (default=None)
            The xlabels argument can be used to assign names to data columns.
            This argument is not generally needed, as names can be captured
            automatically from most standard data structures.
            If included, must have length n, where n is the number of features.
            Note that column order is used to compute term values and make
            predictions, not column names.

        skip_scrub : bool, optional (default=False)
            If True, the inputs are used as given and are not passed through
            the input sanitization step.  Used by fit, which has already
            sanitized its inputs.


        '''
        # Label and format data
        if xlabels is None:
            label_source = X
            if y is None and isinstance(X, tuple) and len(X) == 2:
                # Output of patsy.dmatrices is (response, predictors); only
                # the predictors carry the column names wanted here.
                label_source = X[1]
            self.xlabels_ = self._scrape_labels(label_source)
        else:
            self.xlabels_ = xlabels
        if not skip_scrub:
            X, y, sample_weight, output_weight, missing = self._scrub(
                X, y, sample_weight, output_weight, missing)

        # Do the actual work
        # Only the non-None instance attributes named in
        # forward_pass_arg_names are forwarded as keyword arguments.
        args = self._pull_forward_args(**self.__dict__)
        forward_passer = ForwardPasser(
            X, missing, y, sample_weight,
            xlabels=self.xlabels_, linvars=linvars, **args)
        forward_passer.run()
        self.forward_pass_record_ = forward_passer.trace()
        self.basis_ = forward_passer.get_basis()

    def pruning_pass(self, X, y=None, sample_weight=None, output_weight=None,
                     missing=None, skip_scrub=False):
        '''
        Run the pruning stage of the multivariate adaptive regression
        splines algorithm on the basis held in ``self.basis_``.

        Most users should call the fit method instead, which chains the
        forward pass, the pruning pass, and the final linear fit.  On
        return, ``feature_importances_`` and ``pruning_pass_record_`` are
        set on the estimator.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features
            The training predictors.  May be a numpy array, a pandas
            DataFrame, a patsy DesignMatrix, or a tuple of patsy
            DesignMatrix objects as output by patsy.dmatrices.

        y : array-like, optional (default=None), shape = [m, p] where m is the
            number of samples, p the number of outputs.
            The training response.  May be a numpy array, a pandas
            DataFrame, a Patsy DesignMatrix, or None (default) if X was the
            output of a call to patsy.dmatrices (in which case, X contains
            the response).

        sample_weight : array-like, optional (default=None), shape = [m]
             where m is the number of samples.
             Sample weights for training.  Weights must be greater than or
             equal to zero.  Rows with zero weight do not contribute at all.
             Weights are useful when dealing with heteroscedasticity.  In such
             cases, the weight should be proportional to the inverse of the
             (known) variance.

        output_weight : array-like, optional (default=None), shape = [p]
             where p is the number of outputs.
             The total mean squared error (MSE) is a weighted sum of
             mean squared errors (MSE) associated to each output, where
             the weights are given by output_weight.
             Output weights must be greater than or equal
             to zero.  Outputs with zero weight do not contribute at all
             to the total mean squared error (MSE).

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            May be a numpy array, a pandas DataFrame, or a patsy
            DesignMatrix.  All entries are interpreted as boolean values,
            with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred
            from X if allow_missing is True.

        skip_scrub : bool, optional (default=False)
            If True, the arguments are used exactly as given instead of
            being passed through ``self._scrub`` first.

        '''
        # Normalise the inputs unless the caller has already done so
        if not skip_scrub:
            scrubbed = self._scrub(X, y, sample_weight, output_weight, missing)
            X, y, sample_weight, output_weight, missing = scrubbed

        # Collect the pruning-related settings stored on this estimator
        pruning_args = self._pull_pruning_args(**self.__dict__)

        # Run the pruning pass over the current basis
        passer = PruningPasser(
            self.basis_, X, missing, y, sample_weight, **pruning_args)
        passer.run()

        importances = passer.feature_importance
        self._feature_importances_dict = importances
        n_criteria = len(importances)
        if n_criteria == 0:
            exposed = None
        elif n_criteria == 1:
            # A single criterion is exposed directly rather than as a dict
            exposed = importances[next(iter(importances.keys()))]
        else:
            exposed = importances
        self.feature_importances_ = exposed
        self.pruning_pass_record_ = passer.trace()


#         # Format data
#         X, y, sample_weight, output_weight, missing = self._scrub(
#             X, y, sample_weight, output_weight, missing)
#
#         # Dimensions
#         m, n = X.shape
#         basis_size = len(self.basis_)
#         pruned_basis_size = self.basis_.plen()
#
#         self.pruning_pass_record_ = PruningPassRecord(m, n, args['penalty'], sst, pruned_basis_size)
#
#
#
#         # Pull arguments from self
#         args = self._pull_pruning_args(**self.__dict__)
#         penalty = args.get('penalty', 3.0)
#
#         # Compute prune sets
#         prune_sets = [[basis_size - i for i in range(n)] for n in range(basis_size)]
#
#         # Create the record object
#         record = PruningPassRecord('gcv', basis_size)
#
#         # Score each prune set to find the best
#         best_prune_set = []
#         best_score = float('inf')
#         for prune_set in prune_sets:
#             print 1
#             print prune_set
#             sys.stdout.flush()
#             # Prune this prune set
#             for idx in prune_set:
#                 self.basis_[idx].prune()
#
#             print 2
#             sys.stdout.flush()
#             # Score this prune set
#             self.linear_fit(X, y, sample_weight, output_weight, missing, skip_scrub=True)
#             y_pred = self.predict(X, missing, skip_scrub=True)
#             score = gcv(np.mean(((y - y_pred) * np.sqrt(sample_weight)) ** 2),
#                         basis_size - len(prune_set), m, penalty)
#             r2 = self.score(X, y, sample_weight, output_weight, missing, skip_scrub=True)
#
#             print 3
#             sys.stdout.flush()
#             # Minimizer
#             if score < best_score:
#                 best_score = score
#                 best_prune_set = prune_set
#
#             print 4
#             sys.stdout.flush()
#             # Unprune for next iteration
#             for idx in prune_set:
#                 self.basis_[idx].unprune()
#
#             print 5
#             sys.stdout.flush()
#             # Add to the record
#             record.add(prune_set, score, r2)
#
#         # Apply the best prune set
#         for idx in best_prune_set:
#             self.basis_[idx].prune()
#
#         self.pruning_pass_record_ = record

    def forward_trace(self):
        '''
        Return the record of the forward pass, or None if no forward pass
        record has been stored on this object yet.
        '''
        return getattr(self, 'forward_pass_record_', None)

    def pruning_trace(self):
        '''
        Return the record of the pruning pass, or None if no pruning pass
        record has been stored on this object yet.
        '''
        return getattr(self, 'pruning_pass_record_', None)

    def trace(self):
        '''
        Return an EarthTrace bundling the forward pass record and the
        pruning pass record.  Either record is None if the corresponding
        pass has not stored one.
        '''
        forward_record = self.forward_trace()
        pruning_record = self.pruning_trace()
        return EarthTrace(forward_record, pruning_record)

    def summary(self):
        '''
        Return a string describing the model.

        The string consists of a title line, a table with one row per
        basis function (its description, whether it is pruned, and one
        coefficient per output), and a final line with the MSE, GCV, RSQ
        and GRSQ of the fit.  If no forward pass record exists, only the
        text 'Untrained Earth Model' is returned.
        '''
        if self.forward_trace() is None:
            return 'Untrained Earth Model'

        if self.pruning_trace() is None:
            title = 'Unpruned Earth Model\n'
        else:
            title = 'Earth Model\n'

        n_outputs = self.coef_.shape[0]
        header = ['Basis Function', 'Pruned']
        if n_outputs > 1:
            header.extend('Coefficient %d' % k for k in range(n_outputs))
        else:
            header.append('Coefficient')

        rows = []
        # coef_ only has columns for unpruned basis functions, so track the
        # column of the next unpruned one separately from the loop position.
        col = 0
        for bf in self.basis_:
            if bf.is_pruned():
                rows.append([str(bf), 'Yes'] + ['None'] * n_outputs)
            else:
                coefs = ['%g' % self.coef_[out, col]
                         for out in range(n_outputs)]
                rows.append([str(bf), 'No'] + coefs)
                col += 1

        table = ascii_table(header, rows)
        fit_stats = 'MSE: %.4f, GCV: %.4f, RSQ: %.4f, GRSQ: %.4f' % (
            self.mse_, self.gcv_, self.rsq_, self.grsq_)
        return title + table + '\n' + fit_stats

    def summary_feature_importances(self, sort_by=None):
        """
        Returns a string containing a printable summary of the estimated
        feature importances.

        The first line lists the names of the available importance
        criteria; each following line holds a feature label and its
        importance under every criterion, formatted with two decimals.
        An empty string is returned when no feature importances are
        available.

        Parameters
        ----------
        sort_by : string, optional
            it refers to a feature importance type name : 'gcv', 'rss'
            or 'nb_subsets'.
            In case it is provided, the features are sorted by decreasing
            importance according to the feature importance type
            corresponding to `sort_by`. In case it is not provided, the
            features are not sorted.

        Returns
        -------
        summary : string
            The formatted table, one line per feature.

        Raises
        ------
        ValueError
            If `sort_by` is given but is not one of the available
            feature importance type names.
        """
        importances = self._feature_importances_dict
        if not importances:
            return ''

        criteria = list(importances.keys())
        if sort_by:
            if sort_by not in importances:
                raise ValueError('Invalid feature importance type name '
                                 'to sort with : %s, available : %s' % (
                                     sort_by, criteria))
            # argsort is ascending; reverse it to list the most important
            # features first.
            indices = np.argsort(importances[sort_by])[::-1]
        else:
            indices = np.arange(len(self.xlabels_))

        max_label_length = max(map(len, self.xlabels_)) + 5
        labels = np.array(self.xlabels_)[indices]
        # Reorder every criterion once, up front, instead of once per label.
        ordered = [(name, np.asarray(importances[name])[indices])
                   for name in criteria]

        lines = [max_label_length * ' ' + '    '.join(criteria)]
        for i, label in enumerate(labels):
            cells = ''.join('%.2f' % values[i] + len(name) * ' '
                            for name, values in ordered)
            lines.append(
                label + ' ' * (max_label_length - len(label)) + cells)
        return '\n'.join(lines) + '\n'

    def linear_fit(self, X, y=None, sample_weight=None, output_weight=None,
                   missing=None, skip_scrub=False):
        '''
        Solve the linear least squares problem to determine the coefficients
        of the unpruned basis functions.

        One least squares problem is solved per output column of y.  On
        return the attributes ``coef_`` (one row per output), ``mse_``,
        ``gcv_``, ``rsq_`` and ``grsq_`` are set on the estimator.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples and n
            is the number of features The training predictors.  The X parameter
            can be a numpy array, a pandas DataFrame, a patsy
            DesignMatrix, or a tuple of patsy DesignMatrix objects as output
            by patsy.dmatrices.

        y : array-like, optional (default=None), shape = [m, p] where m is the
            number of samples, p the number of outputs.
            The y parameter can be a numpy array, a pandas DataFrame,
            a Patsy DesignMatrix, or can be left as None (default) if X was
            the output of a call to patsy.dmatrices (in which case, X contains
            the response).

        sample_weight : array-like, optional (default=None), shape = [m]
             where m is the number of samples.
             Sample weights for training.  Weights must be greater than or
             equal to zero. Rows with zero weight do not contribute at all.
             Weights are useful when dealing with heteroscedasticity.  In such
             cases, the weight should be proportional to the inverse of the
             (known) variance.

        output_weight : array-like, optional (default=None), shape = [p]
             where p is the number of outputs.
             Output weights must be greater than or equal to zero.
             NOTE(review): this value is passed through ``_scrub`` but is
             not otherwise used by this method; the reported statistics
             are not weighted per output.

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            The missing parameter can be a numpy array, a pandas DataFrame, or
            a  patsy DesignMatrix.  All entries will be interpreted as boolean
            values, with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred from
            X if allow_missing is True.

        skip_scrub : bool, optional (default=False)
            If True, the arguments are used exactly as given instead of
            being passed through ``self._scrub`` first.

        '''
        # Format data
        if not skip_scrub:
            X, y, sample_weight, output_weight, missing = self._scrub(
                X, y, sample_weight, output_weight, missing)

        # The basis-space matrix does not depend on the output, so build it
        # once.  Each output works on its own copy because the weights are
        # applied to the matrix in place.
        B_unweighted = self.transform(X, missing)
        per_output_weights = sample_weight.shape[1] > 1

        # Solve the linear least squares problem, one output at a time
        coefs = []
        resids = []
        total_weight = 0.
        mse0 = 0.
        for i in range(y.shape[1]):

            # Figure out the weight column
            if per_output_weights:
                w = sample_weight[:, i]
            else:
                w = sample_weight[:, 0]

            # Weighted copy of the basis matrix (same memory order as the
            # array produced by transform)
            B = B_unweighted.copy(order='F')
            apply_weights_2d(B, w)

            # Compute total weight
            total_weight += np.sum(w)

            # Apply weights to the current output only
            weighted_yi = y[:, i] * np.sqrt(w)

            # Accumulate the squared error of a constant (mean) model
            mse0 += np.sum((weighted_yi - np.average(weighted_yi)) ** 2)

            coef, resid = np.linalg.lstsq(B, weighted_yi)[0:2]
            coefs.append(coef)
            # lstsq returns an empty residual array when the system is rank
            # deficient or not over-determined; compute it by hand then.
            if resid.size == 0:
                resid = np.array(
                    [np.sum((np.dot(B, coef) - weighted_yi) ** 2)])
            resids.append(resid)
        resid_ = np.array(resids)
        self.coef_ = np.array(coefs)

        # Compute the final mse, gcv, rsq, and grsq (may be different from the
        # pruning scores if the model has been smoothed)
        n_samples = X.shape[0]
        n_terms = B_unweighted.shape[1]
        penalty = self.get_penalty()
        self.mse_ = np.sum(resid_) / total_weight
        mse0 = mse0 / total_weight
        self.gcv_ = gcv(self.mse_, n_terms, n_samples, penalty)
        gcv0 = gcv(mse0, 1, n_samples, penalty)
        # A zero baseline means the response is already fit exactly by a
        # constant; report a perfect score instead of dividing by zero.
        if mse0 != 0.:
            self.rsq_ = 1.0 - (self.mse_ / mse0)
        else:
            self.rsq_ = 1.0
        if gcv0 != 0.:
            self.grsq_ = 1.0 - (self.gcv_ / gcv0)
        else:
            self.grsq_ = 1.0

#
#
#
#         for p, coef in enumerate(self.coef_):
#             mse_p = resid_[p].sum() / float(X.shape[0])
#             gcv_[p] = gcv(mse_p,
#                           coef.shape[0], X.shape[0],
#                           self.get_penalty())
#             self.gcv_ += gcv_[p] * output_weight[p]
#         y_avg = np.average(y, weights=sample_weight if
#                            sample_weight.shape == y.shape else sample_weight.flatten(), axis=0)
#         y_sqr = (y - y_avg[np.newaxis, :]) ** 2
#
#         rsq_ = ((1 - resid_.sum(axis=1) / y_sqr.sum(axis=0)) * output_weight)
#         self.rsq_ = rsq_.sum()
#         gcv0 = np.empty(y.shape[1])
#         for p in range(y.shape[1]):
#             mse0_p = (y_sqr[:, p].sum()) / float(X.shape[0])
#             gcv0[p] = gcv(mse0_p, 1, X.shape[0], self.get_penalty())
#         self.grsq_ = ((1 - (gcv_ / gcv0)) * output_weight).sum()

    def predict(self, X, missing=None, skip_scrub=False):
        '''
        Predict the response for the input data X.

        X is transformed into basis space and multiplied by the fitted
        coefficients.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples and n
            is the number of features
            The predictors.  May be a numpy array, a pandas DataFrame, or a
            patsy DesignMatrix.

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            May be a numpy array, a pandas DataFrame, or a patsy
            DesignMatrix.  All entries are interpreted as boolean values,
            with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred
            from X if allow_missing is True.

        skip_scrub : bool, optional (default=False)
            If True, X and missing are handed to ``transform`` without
            first being passed through ``self._scrub_x`` here.

        Returns
        -------
        y : array of shape = [m] or [m, p] where m is the number of samples
            and p is the number of outputs
            The predicted values.  A one-dimensional array is returned
            when there is a single output.
        '''
        if not skip_scrub:
            X, missing = self._scrub_x(X, missing)
        basis_values = self.transform(X, missing)
        predictions = np.dot(basis_values, self.coef_.T)
        # Collapse the output axis for single-output models
        single_output = predictions.shape[1] == 1
        return predictions[:, 0] if single_output else predictions

    def predict_deriv(self, X, variables=None, missing=None):
        '''
        Predict the first derivatives of the response based on the input
        data X.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples and n
            is the number of features The training predictors. The X parameter
            can be a numpy array, a pandas DataFrame, or a patsy DesignMatrix.

        variables : list, string or int, optional (default=None)
            The variables over which derivatives will be computed, given as
            column indices and/or column names (looked up in ``xlabels_``).
            A single name or index is treated as a one-element list.  Each
            column in the resulting array corresponds to a variable.  If not
            specified, all variables are used (even if some are not relevant
            to the final model and have derivatives that are identically zero).

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            The missing parameter can be a numpy array, a pandas DataFrame, or
            a patsy DesignMatrix.  All entries will be interpreted as boolean
            values, with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred from
            X if allow_missing is True.

        Returns
        -------
        X_deriv : array of shape = [m, n, p] where m is the number of samples, n
                  is the number of features if 'variables' is not specified
                  otherwise it is len(variables) and p is the number of outputs.
                  For each sample, X_deriv represents the first derivative of
                  each response  with respect to each variable.

        Raises
        ------
        ValueError
            If a variable name is not found in ``xlabels_``.

        '''

        check_is_fitted(self, "basis_")

        # Integer indices may arrive as numpy integers (e.g. from np.arange),
        # which are not instances of the builtin int on Python 3.
        index_types = (int, np.integer)

        # A single name or index is shorthand for a one-element list
        if isinstance(variables, (str,) + index_types):
            variables = [variables]
        if variables is None:
            variables_of_interest = list(range(len(self.xlabels_)))
        else:
            labels = list(self.xlabels_)
            variables_of_interest = [
                int(var) if isinstance(var, index_types) else labels.index(var)
                for var in variables]
        X, missing = self._scrub_x(X, missing)
        J = np.zeros(shape=(X.shape[0],
                            len(variables_of_interest),
                            self.coef_.shape[0]))
        # b and j are per-sample buffers handed to transform_deriv, which
        # fills J in place.
        b = np.empty(shape=X.shape[0])
        j = np.empty(shape=X.shape[0])
        self.basis_.transform_deriv(
            X, missing, b, j, self.coef_, J, variables_of_interest, True)
        return J

    def score(self, X, y=None, sample_weight=None, output_weight=None,
              missing=None, skip_scrub=False):
        '''
        Calculate the sample-weighted r^2 of the model on data X and y.

        The score is one minus the ratio of the weighted sum of squared
        residuals to the weighted sum of squared deviations of y from its
        weighted per-output mean, pooled over all outputs.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features The training predictors.
            The X parameter can be a numpy array, a pandas DataFrame, a patsy
            DesignMatrix, or a tuple of patsy DesignMatrix objects as output
            by patsy.dmatrices.

        y : array-like, optional (default=None), shape = [m, p] where m is the
            number of samples, p the number of outputs.
            The y parameter can be a numpy array, a pandas DataFrame,
            a Patsy DesignMatrix, or can be left as None (default) if X was
            the output of a call to patsy.dmatrices (in which case, X contains
            the response).

        sample_weight : array-like, optional (default=None), shape = [m]
             where m is the number of samples.
             Sample weights.  Weights must be greater than or
             equal to zero. Rows with zero weight do not contribute at all.
             Weights are useful when dealing with heteroscedasticity.  In such
             cases, the weight should be proportional to the inverse of the
             (known) variance.

        output_weight : array-like, optional (default=None), shape = [p]
             where p is the number of outputs.
             Output weights must be greater than or equal to zero.
             NOTE(review): this value is passed through ``_scrub`` but is
             not otherwise used when computing the score.

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            The missing parameter can be a numpy array, a pandas DataFrame, or
            a patsy DesignMatrix.  All entries will be interpreted as boolean
            values, with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred from
            X if allow_missing is True.

        skip_scrub : bool, optional (default=False)
            If True, the arguments are used exactly as given instead of
            being passed through ``self._scrub`` first.

        Returns
        -------

        score : float with a maximum value of 1 (it can be negative). The score
                is the weighted r^2 of the model on data X and y, the higher
                the score the better the fit is.

        '''
        check_is_fitted(self, "basis_")
        if not skip_scrub:
            X, y, sample_weight, output_weight, missing = self._scrub(
                X, y, sample_weight, output_weight, missing)
        # Give every output its own weight column so the element-wise
        # products below line up with y.
        if sample_weight.shape[1] == 1 and y.shape[1] > 1:
            sample_weight = np.repeat(sample_weight, y.shape[1], axis=1)
        # Forward the missingness mask; without it the entries flagged as
        # missing would be treated as observed values during prediction.
        y_hat = self.predict(X, missing=missing)
        # predict returns a 1-d array for single-output models
        if len(y_hat.shape) == 1:
            y_hat = y_hat[:, None]

        residual = y - y_hat
        mse = np.sum(sample_weight * (residual ** 2))
        y_avg = np.average(y, weights=sample_weight, axis=0)

        mse0 = np.sum(sample_weight * ((y - y_avg) ** 2))
        return 1 - (mse / mse0)

    def score_samples(self, X, y=None, missing=None):
        '''

        Calculate sample-wise fit scores.

        Each score is ``1 - (y - y_hat) ** 2 / y ** 2``, computed
        element-wise for every sample and output.

        Parameters
        ----------

        X : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features The training predictors.
            The X parameter can be a numpy array, a pandas DataFrame, a patsy
            DesignMatrix, or a tuple of patsy DesignMatrix objects as output
            by patsy.dmatrices.

        y : array-like, optional (default=None), shape = [m, p] where m is the
            number of samples, p the number of outputs.
            The y parameter can be a numpy array, a pandas DataFrame,
            a Patsy DesignMatrix, or can be left as None (default) if X was
            the output of a call to patsy.dmatrices (in which case, X contains
            the response).

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            The missing parameter can be a numpy array, a pandas DataFrame, or
            a  patsy DesignMatrix.  All entries will be interpreted as boolean
            values, with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred from
            X if allow_missing is True.

        Returns
        -------

        scores : array of shape=[m, p] of floats with maximum value of 1
                 (it can be negative).
                 The scores represent how good each output of each example is
                 predicted, a perfect score would be 1
                 (the score can be negative).  Entries where y is zero
                 involve a division by zero and are not finite.

        '''
        X, y, sample_weight, output_weight, missing = self._scrub(
            X, y, None, None, missing)
        y_hat = self.predict(X, missing=missing)
        # predict returns a 1-d array for single-output models; restore the
        # output axis so the subtraction is element-wise instead of
        # broadcasting [m, 1] against [m] into an [m, m] array.
        if y_hat.ndim == 1:
            y_hat = y_hat[:, np.newaxis]
        return 1 - (y - y_hat) ** 2 / y ** 2

    def transform(self, X, missing=None):
        '''
        Transform X into the basis space.

        Normally, users will call the predict method instead, which both
        transforms into basis space and calculates the weighted sum of
        basis terms to produce a prediction of the response.  Calling
        transform directly is useful when other statistical or machine
        learning algorithms, such as generalized linear regression, are to
        be applied in basis space.


        Parameters
        ----------
        X : array-like, shape = [m, n] where m is the number of samples and n
            is the number of features
            The predictors.  May be a numpy array, a pandas DataFrame, or a
            patsy DesignMatrix.

        missing : array-like, shape = [m, n] where m is the number of samples
            and n is the number of features.
            May be a numpy array, a pandas DataFrame, or a patsy
            DesignMatrix.  All entries are interpreted as boolean values,
            with True indicating the corresponding entry in X should be
            interpreted as missing.  If the missing argument is not used but
            the X argument is a pandas DataFrame, missing will be inferred
            from X if allow_missing is True.

        Returns
        -------

        B: array of shape [m, nb_terms] where m is the number of samples and
           nb_terms is ``self.basis_.plen()`` (presumably the number of
           unpruned basis functions -- confirm against the basis class).
           B represents the values of the basis functions evaluated at each
           sample.
        '''

        check_is_fitted(self, "basis_")
        X, missing = self._scrub_x(X, missing)
        n_samples = X.shape[0]
        n_terms = self.basis_.plen()
        # Column-major output buffer, filled in place by the basis
        basis_values = np.empty(shape=(n_samples, n_terms), order='F')
        self.basis_.transform(X, missing, basis_values)
        return basis_values

    def get_penalty(self):
        '''
        Get the penalty parameter being used.

        Returns the instance's ``penalty`` attribute when it has been set
        to something other than None, and the default of 3.0 otherwise.
        '''
        has_penalty = 'penalty' in vars(self)
        if not has_penalty or self.penalty is None:
            return 3.0
        return self.penalty


class EarthTrace(object):
    '''
    Pair of records describing the forward pass and the pruning pass.

    Either record may be None when the corresponding pass has not been
    run; such a record is printed as 'None' without a final summary line.

    Parameters
    ----------
    forward_trace : object or None
        Record of the forward pass.  When not None it must support str()
        and provide a final_str() method.

    pruning_trace : object or None
        Record of the pruning pass, with the same requirements.
    '''

    def __init__(self, forward_trace, pruning_trace):
        self.forward_trace = forward_trace
        self.pruning_trace = pruning_trace

    def __eq__(self, other):
        return (self.__class__ is other.__class__ and
                self.forward_trace == other.forward_trace and
                self.pruning_trace == other.pruning_trace)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__
        return not self == other

    @staticmethod
    def _format_section(title, trace):
        '''Return the title, the trace and, if available, its final line.'''
        lines = [title, str(trace)]
        # A missing trace has no final summary to report
        if trace is not None:
            lines.append(trace.final_str())
        return '\n'.join(lines)

    def __str__(self):
        forward = self._format_section('Forward Pass', self.forward_trace)
        pruning = self._format_section('Pruning Pass', self.pruning_trace)
        return forward + '\n\n' + pruning + '\n'
